%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Counterfactual 1: keep C at baseline estimate
% but change persistence
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Load the outputs of data_moments for the 'persist' exercise.
% var_M_full and X_monthly are returned here but not referenced again in
% this section (they are cleared at the end of it).
[M_full,var_M_full,W_full,X_monthly,random_state]     ...
                                        = data_moments(project_folder,'persist');

% Inputs handed to setup_function. mean_reversion is overwritten inside the
% loop below; the other fields stay at these values for the "with
% complementarities" run.
setup_inputs.N                          = 15;
setup_inputs.mean_reversion             = 90;
setup_inputs.xi                         = 0.0117;
setup_inputs.sigma                      = 0.043803490552968;
setup_inputs.S                          = 0.327511125;
setup_inputs.M_e                        = 0.965950533806254; 
setup_inputs.C                          = 0.0699930625;
setup_inputs.X0_mean                    = 0.005;

% Pristine copy used to undo the per-run modifications made in the loop
setup_inputs_saved                      = setup_inputs;
pars_saved.k                            = 0.250;

% Grid of persistence values (same units as setup_inputs.mean_reversion)
N_mr                                    = 10;
mean_reversion_vec                      = linspace(90,360,N_mr)';

% The leading 0* makes every entry of scaling equal to zero, so the
% "without complementarities" run below always multiplies C by zero.
scaling                                 = 0*linspace(0.95,0.40,N_mr)';

for n=1:numel(mean_reversion_vec)

    % Model with complementarities: C stays at its value in setup_inputs_saved
    setup_inputs.mean_reversion                 = mean_reversion_vec(n);
    setup_inputs.k                              = pars_saved.k;
    [pars_baseline{n},Nu_baseline{n}]           = setup_function(setup_inputs,random_state,'persist');
    baseline{n}                                 = evaluate(pars_baseline{n},M_full,W_full,Nu_baseline{n},0);
    setup_inputs                                = setup_inputs_saved;   % drop the per-run changes

    % Model without complementarities: same persistence and k, C rescaled
    % by scaling(n) and InitX set to 0
    setup_inputs.mean_reversion                 = mean_reversion_vec(n);
    setup_inputs.k                              = pars_saved.k;
    setup_inputs.InitX                          = 0;
    setup_inputs.C                              = scaling(n)*pars_baseline{n}.C;
    [pars_cf{n},Nu_cf{n}]                       = setup_function(setup_inputs,random_state,'persist');
    cf{n}                                       = evaluate(pars_cf{n},M_full,W_full,Nu_cf{n},0);
    setup_inputs                                = setup_inputs_saved;   % so InitX / C do not leak into the next iteration

end

% Format data: for each persistence value take row 8 of OutputFinal.M and
% average it across columns (presumably the 8-month response, as in the
% figure title -- TODO confirm against evaluate).
IRF_m8_baseline                         = NaN(size(mean_reversion_vec));
IRF_m8_cf                               = NaN(size(mean_reversion_vec));
for n=1:N_mr
        IRF_m8_baseline(n)                      = mean(baseline{n}.OutputFinal.M(8,:),2);
        IRF_m8_cf(n)                            = mean(cf{n}.OutputFinal.M(8,:),2);
end

% Percentage gap between counterfactual and baseline, smoothed, then
% normalised by its value at the first grid point.
G_m8_noC                                = smooth(100*(1-IRF_m8_cf./IRF_m8_baseline));
G_m8_noC                                = G_m8_noC/G_m8_noC(1);

% Free everything that is not needed further down. root_folder has to
% survive as well: the figure export at the end of this script builds its
% output path from it, and clearing it here makes that print call fail.
clearvars -except project_folder root_folder random_state M_full W_full G_m8_noC

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Counterfactual 2: re-estimate C
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% Estimation --- about 4h runtime on Mac laptop
% One call to estimate_mr per persistence value on the grid.
Nk                                                                      = 1e1+1;
mean_reversion_vec                                                      = linspace(90,360,Nk)';
outTable                                                                = cell(1,Nk);
pars                                                                    = cell(1,Nk);
Nu                                                                      = cell(1,Nk);
for nk=1:Nk
    fprintf('Estimation for mr = %.0f\n',mean_reversion_vec(nk));
    fprintf('***********************\n');
    fprintf('\n');
    [outTable{nk},~,pars{nk},Nu{nk}]                               = estimate_mr(mean_reversion_vec(nk),M_full,W_full,random_state);
    disp(outTable{nk}(1,:))
end

% Counterfactuals
pars_baseline                               = cell(1,Nk);
Nu_baseline                                 = cell(1,Nk);
baseline                                    = cell(1,Nk);
pars_cf                                     = cell(1,Nk);
Nu_cf                                       = cell(1,Nk);
cf                                          = cell(1,Nk);
for n=1:numel(mean_reversion_vec)

    % Start every iteration from an empty struct. Without this reset the
    % InitX field set for the "without complementarities" run of iteration
    % n-1 would still be present when the baseline of iteration n is set
    % up, so the baseline would be built differently for n = 1 and n >= 2.
    setup_inputs                                = struct();

    % Model with complementarities: inputs are read from the first row of
    % the estimation output (columns 1 and 3-7; column 2 is not used here).
    % str2double is used instead of str2num because the entries are scalar
    % numbers and str2num evaluates its argument with eval.
    setup_inputs.targeted_pars                  = {'S','M_e','C'};
    setup_inputs.mean_reversion                 = mean_reversion_vec(n);
    setup_inputs.k                              = str2double(outTable{n}{1,1});
    setup_inputs.xi                             = str2double(outTable{n}{1,3});
    setup_inputs.sigma                          = str2double(outTable{n}{1,4});
    setup_inputs.S                              = str2double(outTable{n}{1,5});
    setup_inputs.M_e                            = str2double(outTable{n}{1,6});
    setup_inputs.C                              = str2double(outTable{n}{1,7});
    [pars_baseline{n},Nu_baseline{n}]           = setup_function(setup_inputs,random_state,'persist');
    baseline{n}                                 = evaluate(pars_baseline{n},M_full,W_full,Nu_baseline{n},0);

    % Model without complementarities: identical inputs except that C is
    % set to zero and InitX is a zero column with one entry per simulation
    % of the baseline run.
    setup_inputs.C                              = 0;
    setup_inputs.InitX                          = zeros(Nu_baseline{n}.Nsim,1);
    [pars_cf{n},Nu_cf{n}]                       = setup_function(setup_inputs,random_state,'persist');
    cf{n}                                       = evaluate(pars_cf{n},M_full,W_full,Nu_cf{n},0);

end

% Format data: row 8 of OutputFinal.M averaged across columns (presumably
% the 8-month response, as in the figure title -- TODO confirm).
IRF_m8_baseline                         = NaN(size(mean_reversion_vec));
IRF_m8_cf                               = NaN(size(mean_reversion_vec));
for n=1:numel(mean_reversion_vec)

        IRF_m8_baseline(n)                      = mean(baseline{n}.OutputFinal.M(8,:),2);
        IRF_m8_cf(n)                            = mean(cf{n}.OutputFinal.M(8,:),2);

end

% The two interp1 calls pair x(1:end-1) with y(2:end): the value at the
% first grid point is discarded, every other value is moved one grid step
% towards lower persistence, and the last grid point is filled by linear
% extrapolation.
% NOTE(review): confirm this one-step shift is still wanted now that every
% iteration builds its baseline inputs from scratch.
IRF_m8_cf                               = interp1(mean_reversion_vec(1:(end-1)),IRF_m8_cf(2:end),mean_reversion_vec,'linear','extrap');
IRF_m8_baseline                         = interp1(mean_reversion_vec(1:(end-1)),IRF_m8_baseline(2:end),mean_reversion_vec,'linear','extrap');

% Percentage gap between counterfactual and baseline on the estimation grid
G_m8                                    = 100*(1-IRF_m8_cf./IRF_m8_baseline);

% Re-sample on a 10-point grid. G_m8_noC is plotted against mr_grid further
% down, so the number of points here has to equal numel(G_m8_noC).
mr_grid                                 = linspace(min(mean_reversion_vec),max(mean_reversion_vec),1e1)';
G_m8_interp                             = smooth(interp1(mean_reversion_vec,G_m8,mr_grid));

% Normalise by the value at the first grid point
G_m8_interp                             = G_m8_interp/G_m8_interp(1);
G_m8_C                                  = G_m8_interp;

%%%%%%%
%%%%%%%
% Plots
%%%%%%%
%%%%%%%

% Draw both normalised series against mr_grid in one figure and export it
% as a PDF under root_folder.
fig_pers                                = figure('Name','Figure: effect of persistence','NumberTitle','off');

% Red crosses: C held fixed; blue circles: C re-estimated. The trailing
% LineWidth pair applies to both lines.
plot(mr_grid,G_m8_noC,'-xr', ...
     mr_grid,G_m8_C,'-ob', ...
     'LineWidth',2);
hold on;

% Axis labels, legend and title, all rendered with the LaTeX interpreter
ylabel('\%','interpreter','latex','rotation',0);
xlabel('(Shock persistence, in days)','interpreter','latex','rotation',0);
ll                                      = legend('Fixing $\hat{C} = \hat{C}_{baseline}$','Re-estimating $\hat{C}$');
ll.Location                             = 'Northeast';
ll.Interpreter                          = 'Latex';

% Axes window: only the part of the grid up to 240 is shown
ax                                      = gca;
set(ax,'XLim',[90 240],'YLim',[0 60],'XTick',mr_grid);
grid on;
tt                                      = title('Fraction of the 8-month response explained by complementarities','interpreter','latex');

% Size the figure in normalised units, then switch to inches so the paper
% size can be matched to the on-screen size.
fig_pers.Units                          = 'Normalized';
fig_pers.OuterPosition                  = [0, 0.04, 0.4, 0.4];
fig_pers.Units                          = 'inches';
screenposition                          = fig_pers.Position;
fig_pers.PaperPosition                  = [0 0 screenposition(3:4)];
fig_pers.PaperSize                      = [screenposition(3:4)];
set(fig_pers,'PaperPositionMode','auto');

% Export
print(fig_pers,[root_folder '/output/figures/Figure_H_17.pdf'],'-dpdf','-fillpage')
